import numpy as np
import random
from hypersense.sampler.base_sampler import BaseSampler


class DistributionMatchSampler(BaseSampler):
    """
    A sampler that selects a subset of data such that the feature distribution
    approximates the original dataset. Compares using histogram-based L2 distance.

    Several random candidate subsets are drawn; for each one, a per-feature
    density histogram is compared against the histogram of the full dataset,
    and the candidate with the smallest summed L2 distance is returned.
    """

    def __init__(self, dataset, sample_size, seed=42, num_candidates=100, feature_indices=None, **kwargs):
        """
        Args:
            dataset: Row-oriented data; converted with ``np.array`` in
                :meth:`sample` and expected to be 2-D there.
            sample_size: Number of rows in the returned subset.
            seed: Seed applied to ``random`` and ``np.random`` in :meth:`sample`.
            num_candidates: How many random subsets to evaluate.
            feature_indices: Column indices to compare. If ``None``, every
                column except the last one is used.
            **kwargs: Forwarded unchanged to ``BaseSampler.__init__``.
        """
        # NOTE(review): sample() reads self.dataset / self.sample_size / self.seed,
        # which are presumably set by BaseSampler.__init__ -- confirm.
        super().__init__(dataset, sample_size, seed, **kwargs)
        self.num_candidates = num_candidates
        self.feature_indices = feature_indices  # If None, use all columns but the last

    def sample(self):
        """
        Draw ``num_candidates`` random subsets and return the best-matching one.

        Returns:
            list: The rows of the winning subset (``ndarray.tolist()`` output).

        Raises:
            ValueError: If ``sample_size`` exceeds the dataset size, if the
                dataset is not 2-D, or if ``num_candidates`` is not positive.
        """
        # Global seeding is kept as-is so that results stay reproducible for
        # callers relying on this side effect.
        random.seed(self.seed)
        np.random.seed(self.seed)

        dataset = np.array(self.dataset)
        total_size = len(dataset)

        if self.sample_size > total_size:
            raise ValueError(f"Sample size ({self.sample_size}) exceeds dataset size ({total_size}).")

        if dataset.ndim != 2:
            raise ValueError(f"Dataset must be 2-dimensional (rows x features), got {dataset.ndim} dimension(s).")

        if self.num_candidates < 1:
            raise ValueError(f"num_candidates must be at least 1, got {self.num_candidates}.")

        # Resolve the default locally instead of overwriting the instance
        # configuration as a side effect of sampling.
        feature_indices = self.feature_indices
        if feature_indices is None:
            feature_indices = list(range(dataset.shape[1] - 1))  # exclude last column if it's a label

        # Bin edges are derived once from the full dataset and shared with every
        # candidate. Letting each subset pick its own range would make bin k of
        # the candidate cover a different interval than bin k of the reference,
        # and the L2 distance between them would not be meaningful.
        bin_edges = self._compute_bin_edges(dataset, feature_indices)
        ref_histograms = self._compute_feature_histograms(dataset, feature_indices, bin_edges=bin_edges)

        best_subset = None
        best_score = float("inf")

        for _ in range(self.num_candidates):
            indices = np.random.choice(total_size, self.sample_size, replace=False)
            candidate = dataset[indices]
            cand_hist = self._compute_feature_histograms(candidate, feature_indices, bin_edges=bin_edges)
            score = self._histogram_l2_distance(ref_histograms, cand_hist)

            # The first candidate is always accepted so that a NaN score (e.g.
            # from an empty subset) cannot leave best_subset unset.
            if best_subset is None or score < best_score:
                best_score = score
                best_subset = candidate

        return best_subset.tolist()

    def _compute_bin_edges(self, data, feature_indices, bins=10):
        """Return one array of histogram bin edges per selected column of ``data``."""
        return [np.histogram_bin_edges(data[:, col], bins=bins) for col in feature_indices]

    def _compute_feature_histograms(self, data, feature_indices, bins=10, bin_edges=None):
        """
        Compute a density histogram for each selected column.

        Args:
            data: 2-D array indexed as ``data[:, column]``.
            feature_indices: Columns to histogram.
            bins: Bin specification passed to ``np.histogram`` when
                ``bin_edges`` is not given.
            bin_edges: Optional sequence with one edge array per entry of
                ``feature_indices``; when given it overrides ``bins`` so that
                different datasets can be histogrammed on identical bins.

        Returns:
            list: One histogram array per entry of ``feature_indices``.
        """
        histograms = []
        for position, col in enumerate(feature_indices):
            edges = bins if bin_edges is None else bin_edges[position]
            hist, _ = np.histogram(data[:, col], bins=edges, density=True)
            histograms.append(hist)
        return histograms

    def _histogram_l2_distance(self, hists1, hists2):
        """Return the sum over features of the L2 norm of the histogram difference."""
        return sum((np.linalg.norm(h1 - h2) for h1, h2 in zip(hists1, hists2)), 0.0)
